{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "46c66f45-4be3-4674-a870-3849c1048ddb",
   "metadata": {},
   "source": [
    "# GRPO for Math (GSM8k)\n",
    "\n",
    "## Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
    "# All rights reserved.\n",
    "#\n",
    "# This source code is licensed under the BSD-style license found in the\n",
    "# LICENSE file in the root directory of this source tree.\n",
    "\n",
    "import asyncio\n",
    "import time\n",
    "import uuid\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Callable\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchstore as ts\n",
    "from datasets import load_dataset\n",
    "from forge.actors._torchstore_utils import (\n",
    "    get_dcp_whole_state_dict_key,\n",
    "    get_param_prefix,\n",
    ")\n",
    "from forge.actors.generator import Generator as Policy\n",
    "from forge.actors.reference_model import ReferenceModel\n",
    "from forge.actors.replay_buffer import ReplayBuffer\n",
    "from forge.actors.trainer import RLTrainer\n",
    "from forge.cli.config import parse\n",
    "from forge.controller.actor import ForgeActor\n",
    "from forge.controller.provisioner import init_provisioner, shutdown\n",
    "from forge.data.rewards import MathReward, ThinkingReward\n",
    "from forge.observability.metric_actors import get_or_create_metric_logger\n",
    "from forge.observability.metrics import record_metric, Reduce\n",
    "from forge.observability.perf_tracker import Tracer\n",
    "\n",
    "from forge.types import LauncherConfig, ProvisionerConfig\n",
    "from forge.util.ops import compute_logprobs\n",
    "from monarch.actor import endpoint\n",
    "from omegaconf import DictConfig\n",
    "from vllm.transformers_utils.tokenizer import get_tokenizer\n",
    "\n",
    "import os\n",
    "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n",
    "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2",
   "metadata": {},
   "source": [
    "## Define Data Structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Episode:\n",
    "    # TODO: add adtional layer for multi-turn\n",
    "    episode_id: str\n",
    "    request: str\n",
    "    policy_version: int\n",
    "    pad_id: int\n",
    "    request_len: int\n",
    "    response_len: int\n",
    "    target: Any | None = None\n",
    "    # processed data\n",
    "    response: str | None = None\n",
    "    request_tokens: list[int] | None = None\n",
    "    response_tokens: list[int] | None = None\n",
    "    ref_logprobs: torch.Tensor | None = None\n",
    "    reward: float | None = None\n",
    "    advantage: float | None = None\n",
    "\n",
    "    @property\n",
    "    def request_tensor(self):\n",
    "        tensor = torch.tensor(self.request_tokens, dtype=torch.long)\n",
    "        if tensor.shape[0] < self.request_len:  # left pad\n",
    "            diff = self.request_len - tensor.shape[0]\n",
    "            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n",
    "        return tensor\n",
    "\n",
    "    @property\n",
    "    def response_tensor(self):\n",
    "        tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n",
    "        if tensor.shape[0] < self.response_len:  # right pad\n",
    "            diff = self.response_len - tensor.shape[0]\n",
    "            tensor = F.pad(tensor, (0, diff), value=self.pad_id)\n",
    "        return tensor\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Group:\n",
    "    group_id: str\n",
    "    episodes: list[Episode]\n",
    "\n",
    "    @classmethod\n",
    "    def new_group(\n",
    "        cls,\n",
    "        group_id: int,\n",
    "        group_size: int,\n",
    "        request: str,\n",
    "        policy_version: int,\n",
    "        pad_id: int,\n",
    "        request_len: int,\n",
    "        response_len: int,\n",
    "        target: Any = None,\n",
    "    ):\n",
    "        episodes = []\n",
    "        for _ in range(group_size):\n",
    "            episodes.append(\n",
    "                Episode(\n",
    "                    episode_id=str(uuid.uuid4()),\n",
    "                    request=request,\n",
    "                    policy_version=policy_version,\n",
    "                    pad_id=pad_id,\n",
    "                    request_len=request_len,\n",
    "                    response_len=response_len,\n",
    "                    target=target,\n",
    "                )\n",
    "            )\n",
    "        return cls(str(group_id), episodes)\n",
    "\n",
    "\n",
    "def collate(batches: list[list[Episode]]):\n",
    "    inputs = []\n",
    "    targets = []\n",
    "    for batch in batches:\n",
    "        request = [e.request_tensor for e in batch]\n",
    "        request = torch.stack(request)  # [b x s]\n",
    "\n",
    "        response = [e.response_tensor for e in batch]\n",
    "        response = torch.stack(response)  # [b x s]\n",
    "\n",
    "        ref_logprobs = [e.ref_logprobs for e in batch]\n",
    "        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]\n",
    "\n",
    "        advantages = [e.advantage for e in batch]\n",
    "        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]\n",
    "\n",
    "        pad_id = batch[0].pad_id\n",
    "        mask = response != pad_id\n",
    "\n",
    "        input = {\"tokens\": torch.cat([request, response], dim=1)}\n",
    "        target = {\n",
    "            \"response\": response,\n",
    "            \"ref_logprobs\": ref_logprobs,\n",
    "            \"advantages\": advantages,\n",
    "            \"padding_mask\": mask,\n",
    "        }\n",
    "        inputs.append(input)\n",
    "        targets.append(target)\n",
    "    return inputs, targets\n",
    "\n",
    "@dataclass\n",
    "class DatasetActor(ForgeActor):\n",
    "    \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n",
    "\n",
    "    path: str = \"openai/gsm8k\"\n",
    "    revision: str = \"main\"\n",
    "    data_split: str = \"train\"\n",
    "    streaming: bool = True\n",
    "    model: str = \"Qwen/Qwen3-1.7B\"\n",
    "\n",
    "    @endpoint\n",
    "    def setup(self):\n",
    "        self._tokenizer = get_tokenizer(self.model)\n",
    "\n",
    "        def gsm8k_transform(sample):\n",
    "            system_prompt = \"\"\"\n",
    "            Put all your scratchpad work between <think> and </think> tags.\n",
    "            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.\n",
    "            \"\"\"\n",
    "            request: str = sample[\"question\"]\n",
    "            as_chat = [\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": request},\n",
    "            ]\n",
    "            formatted_request = self._tokenizer.apply_chat_template(\n",
    "                as_chat,\n",
    "                tokenize=False,\n",
    "                add_generation_prompt=True,\n",
    "            )\n",
    "            target: str = sample[\"answer\"]\n",
    "            formatted_target = target.split(\"#### \")[1]\n",
    "            return {\"request\": formatted_request, \"target\": formatted_target}\n",
    "\n",
    "        ds = load_dataset(\n",
    "            self.path, self.revision, split=self.data_split, streaming=self.streaming\n",
    "        )\n",
    "        ds = ds.map(gsm8k_transform)\n",
    "        ds = ds.shuffle()\n",
    "        self._iterator = iter(ds)\n",
    "\n",
    "    @endpoint\n",
    "    async def sample(self) -> dict[str, str] | None:\n",
    "        try:\n",
    "            sample = next(self._iterator)\n",
    "\n",
    "            # Record dataset metrics\n",
    "            record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n",
    "            record_metric(\n",
    "                \"dataset/sample/avg_sample_len\",\n",
    "                len(sample[\"request\"]),\n",
    "                Reduce.MEAN,\n",
    "            )\n",
    "\n",
    "            return sample\n",
    "        except StopIteration:\n",
    "            return None\n",
    "\n",
    "    @endpoint\n",
    "    async def pad_token(self):\n",
    "        return self._tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef",
   "metadata": {},
   "source": [
    "## Define loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "934aca32-0953-4945-9f99-e7b34804443b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_grpo_loss(\n",
    "    logits: torch.Tensor,\n",
    "    response: torch.Tensor,\n",
    "    ref_logprobs: torch.Tensor,\n",
    "    advantages: torch.Tensor,\n",
    "    padding_mask: torch.Tensor,\n",
    "    beta: float = 0.1,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Example GRPO Loss Function for RLTrainer\n",
    "    \"\"\"\n",
    "    logprobs: torch.Tensor = compute_logprobs(logits, response)\n",
    "\n",
    "    # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n",
    "    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n",
    "    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n",
    "    per_token_loss = -(per_token_policy_loss - beta * kl)\n",
    "    loss = (\n",
    "        ((per_token_loss * padding_mask).sum(dim=1))\n",
    "        / (padding_mask.sum(dim=1).clamp(min=1.0))\n",
    "    ).mean()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104",
   "metadata": {},
   "source": [
    "## Define Reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class RewardActor(ForgeActor):\n",
    "    \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n",
    "\n",
    "    reward_functions: list[Callable]\n",
    "\n",
    "    @endpoint\n",
    "    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n",
    "        total_rewards = 0.0\n",
    "        for reward_fn in self.reward_functions:\n",
    "            reward = reward_fn(prompt, response, target)\n",
    "            total_rewards += reward\n",
    "\n",
    "            # Get a name for the reward function (works for classes, functions, lambdas)\n",
    "            reward_fn_name = getattr(\n",
    "                reward_fn, \"__name__\", reward_fn.__class__.__name__\n",
    "            )\n",
    "            # per function reward\n",
    "            record_metric(\n",
    "                f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n",
    "                reward,\n",
    "                Reduce.SUM,\n",
    "            )\n",
    "            record_metric(\n",
    "                f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n",
    "                reward,\n",
    "                Reduce.MEAN,\n",
    "            )\n",
    "            record_metric(\n",
    "                f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n",
    "                reward,\n",
    "                Reduce.STD,\n",
    "            )\n",
    "\n",
    "            # avg total reward\n",
    "            record_metric(\n",
    "                \"reward/evaluate_response/avg_total_reward\",\n",
    "                reward,\n",
    "                Reduce.MEAN,\n",
    "            )\n",
    "\n",
    "            # count fn calls\n",
    "            record_metric(\n",
    "                f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n",
    "                1,\n",
    "                Reduce.SUM,\n",
    "            )\n",
    "\n",
    "        avg_reward = total_rewards / len(self.reward_functions)\n",
    "        return avg_reward\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ComputeAdvantages(ForgeActor):\n",
    "    \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n",
    "\n",
    "    @endpoint\n",
    "    async def compute(self, group: Group) -> list[float]:\n",
    "        # TODO: add batch processing\n",
    "        rewards = torch.tensor([[e.reward for e in group.episodes]])\n",
    "        mean = rewards.mean(1, keepdim=True)\n",
    "        std = rewards.std(1, keepdim=True)\n",
    "        advantages = (rewards - mean) / (std + 1e-4)\n",
    "        return advantages.squeeze(0).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88523484-b414-41db-bd3f-0d8dbf881a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def drop_weights(version: int):\n",
    "    print(f\"Dropping weights @ version {version}\")\n",
    "    start_time = time.perf_counter()\n",
    "    prefix = get_param_prefix(version)\n",
    "    matching_keys = await ts.keys(prefix)\n",
    "    # TODO: once we have something like `get_meta()` in torchstore, we can just\n",
    "    # query the type of the object instead of relying on keys.\n",
    "    dcp_key = get_dcp_whole_state_dict_key(version)\n",
    "    if dcp_key in matching_keys:\n",
    "        dcp_handle = await ts.get(dcp_key)\n",
    "        dcp_handle.drop()\n",
    "    for key in matching_keys:\n",
    "        await ts.delete(key)\n",
    "    elapsed = time.perf_counter() - start_time\n",
    "    print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72",
   "metadata": {},
   "source": [
    "## Setup Services"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c811974-cd6b-40ed-a179-4511a7a6c489",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "from forge.cli.config import resolve_hf_hub_paths\n",
    "\n",
    "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n",
    "cfg = resolve_hf_hub_paths(cfg)\n",
    "OmegaConf.resolve(cfg)\n",
    "\n",
    "group_size = cfg.group_size # 8\n",
    "max_req_tokens = cfg.max_req_tokens # 512\n",
    "max_res_tokens = cfg.max_res_tokens # 512\n",
    "\n",
    "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n",
    "mlogger = await get_or_create_metric_logger()\n",
    "await mlogger.init_backends.call_one(metric_logging_cfg)\n",
    "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n",
    "\n",
    "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n",
    "    DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n",
    "    Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n",
    "    RLTrainer.options(**cfg.actors.trainer).as_actor(\n",
    "        **cfg.trainer, loss=simple_grpo_loss\n",
    "    ),\n",
    "    ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n",
    "        **cfg.replay_buffer, collate=collate\n",
    "    ),\n",
    "    ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n",
    "    ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n",
    "    RewardActor.options(**cfg.services.reward_actor).as_service(\n",
    "        reward_functions=[MathReward(), ThinkingReward()]\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7",
   "metadata": {},
   "source": [
    "## Rollout Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def continuous_rollouts():\n",
    "    rollout_count = 0\n",
    "    pad_id = await dataloader.pad_token.call_one()\n",
    "    while True:\n",
    "        t = Tracer(\"main_perf/continuous_rollouts\")\n",
    "        t.start()\n",
    "        sample = await dataloader.sample.call_one()\n",
    "        if sample is None:\n",
    "            print(\"Dataloader is empty, exiting continuous rollout\")\n",
    "            return\n",
    "\n",
    "        t.step(\"data_loading\")\n",
    "\n",
    "        prompt, target = sample[\"request\"], sample[\"target\"]\n",
    "        responses = await policy.generate.route(prompt)\n",
    "        # TODO: this shall be part of the responses metadata instead of a separate call\n",
    "        version = await policy.get_version.route()\n",
    "\n",
    "        t.step(\"policy_generation\")\n",
    "\n",
    "        assert (\n",
    "            len(responses) > 0\n",
    "        ), \"Sanity check: Responses should NEVER return empty\"\n",
    "        assert (\n",
    "            version := responses[0].generator_version\n",
    "        ) is not None, \"Response must indicate a version\"\n",
    "        group = Group.new_group(\n",
    "            group_id=rollout_count,\n",
    "            group_size=group_size,\n",
    "            request=prompt,\n",
    "            policy_version=version,\n",
    "            pad_id=pad_id,\n",
    "            request_len=max_req_tokens,\n",
    "            response_len=max_res_tokens,\n",
    "            target=target,\n",
    "        )\n",
    "\n",
    "        input_ids = torch.ones(\n",
    "            (group_size, max_req_tokens + max_res_tokens),\n",
    "            dtype=torch.long,\n",
    "            device=\"cuda\",\n",
    "        )\n",
    "        # Populate episode info and calculate rewards\n",
    "        for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n",
    "            episode.request_tokens = response.prompt_ids\n",
    "            episode.response_tokens = response.token_ids\n",
    "            episode.response = response.text\n",
    "            input_ids[i, :max_req_tokens] = episode.request_tensor\n",
    "            input_ids[i, max_req_tokens:] = episode.response_tensor\n",
    "            episode.reward = await reward_actor.evaluate_response.route(\n",
    "                prompt=prompt, response=response.text, target=target\n",
    "            )\n",
    "\n",
    "        t.step(\"reward_evaluation\")\n",
    "\n",
    "        ref_logprobs = await ref_model.forward.route(\n",
    "            input_ids, max_req_tokens, return_logprobs=True\n",
    "        )\n",
    "        t.step(\"reference_model_calculate_logprobs\")\n",
    "\n",
    "        for i, episode in enumerate(group.episodes):\n",
    "            episode.ref_logprobs = ref_logprobs[i]\n",
    "        del ref_logprobs, input_ids\n",
    "        t.step(\"compute_logprobs\")\n",
    "\n",
    "        # Calculate advantages and add to replay buffer\n",
    "        advantages = await compute_advantages.compute.call_one(group)\n",
    "        for episode, advantage in zip(group.episodes, advantages):\n",
    "            episode.advantage = advantage\n",
    "            await replay_buffer.add.call_one(episode)\n",
    "\n",
    "        # Log metrics\n",
    "        rollout_count += 1\n",
    "        record_metric(\n",
    "            \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n",
    "        )\n",
    "        t.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def continuous_training():\n",
    "    training_step = 0\n",
    "    restart_tracer = True  # Flag to control when to restart tracer\n",
    "    while True:\n",
    "        # Restart tracer when needed (initial start or after completing a training step)\n",
    "        # Otherwise, we cannot measure time waiting for buffer\n",
    "        if restart_tracer:\n",
    "            t = Tracer(\"main_perf/continuous_training\")\n",
    "            t.start()\n",
    "            restart_tracer = False\n",
    "\n",
    "        batch = await replay_buffer.sample.call_one(\n",
    "            curr_policy_version=training_step\n",
    "        )\n",
    "        if batch is None:\n",
    "            await asyncio.sleep(0.1)\n",
    "        else:\n",
    "            t.step(\"waiting_for_buffer\")\n",
    "\n",
    "            inputs, targets = batch\n",
    "            await trainer.train_step.call(inputs, targets)\n",
    "            training_step += 1\n",
    "            t.step(\"train_step\")\n",
    "\n",
    "            await trainer.push_weights.call(training_step)\n",
    "            t.step(\"push_weights\")\n",
    "\n",
    "            await policy.update_weights.fanout(training_step)\n",
    "            update_task = asyncio.create_task(policy.update_weights.fanout(training_step))\n",
    "            t.step(\"update_weights\")\n",
    "\n",
    "            if training_step >= 2:\n",
    "                await drop_weights(training_step - 1)\n",
    "                t.step(\"drop_weights\")\n",
    "\n",
    "            t.stop()\n",
    "            restart_tracer = True\n",
    "\n",
    "            # Flush metrics every training step to WandB\n",
    "            await mlogger.flush.call_one(training_step)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4542863b-59c5-40bc-896c-6d8d44ada00f",
   "metadata": {},
   "source": [
    "## Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58194c13-b75e-405d-ab11-18cbe1874d92",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "num_rollout_threads = 1\n",
    "num_training_threads = 1\n",
    "\n",
    "rollout_tasks = [\n",
    "    asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n",
    "]\n",
    "training_task = asyncio.create_task(continuous_training())\n",
    "\n",
    "try:\n",
    "    await asyncio.gather(*rollout_tasks, training_task)\n",
    "except KeyboardInterrupt:\n",
    "    print(\"Training interrupted by user\")\n",
    "    for rollout_task in rollout_tasks:\n",
    "        rollout_task.cancel()\n",
    "    training_task.cancel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4603b80-1f25-49a1-920e-d24f38dfc687",
   "metadata": {},
   "source": [
    "## Shutdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81",
   "metadata": {},
   "outputs": [],
   "source": [
    "await mlogger.shutdown.call_one()\n",
    "await asyncio.sleep(2)\n",
    "\n",
    "await asyncio.gather(\n",
    "    DatasetActor.shutdown(dataloader),\n",
    "    policy.shutdown(),\n",
    "    RLTrainer.shutdown(trainer),\n",
    "    ReplayBuffer.shutdown(replay_buffer),\n",
    "    ComputeAdvantages.shutdown(compute_advantages),\n",
    "    ref_model.shutdown(),\n",
    "    reward_actor.shutdown(),\n",
    ")\n",
    "await shutdown()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "forge",
   "language": "python",
   "name": "forge"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
